package com.lw.leetcode.tree.c;

import com.lw.test.util.Utils;

import java.util.Arrays;
import java.util.Random;

/**
 * Created with IntelliJ IDEA.
 * tree
 * 2179. Count Good Triplets in an Array
 *
 * <p>Maps every value of {@code nums1} to its index in {@code nums2}; a good triplet then becomes a
 * strictly increasing subsequence of length 3 in that mapped sequence, which is counted with a
 * lazily allocated segment tree over positions {@code [0, n - 1]}.
 *
 * @author liw
 * @version 1.0
 * @date 2022/11/24 11:28
 */
public class GoodTriplets {

    /**
     * Developer harness: runs the sample cases and one large reproducible random input.
     * The fixed seed replaces the old approach of dumping the generated input to a local file.
     */
    public static void main(String[] args) {
        GoodTriplets test = new GoodTriplets();

        // Expected: 4
        System.out.println(test.goodTriplets(new int[]{4, 0, 1, 3, 2}, new int[]{4, 1, 0, 2, 3}));
        // Expected: 1
        System.out.println(test.goodTriplets(new int[]{2, 0, 1, 3}, new int[]{0, 1, 2, 3}));
        // Expected: 11
        System.out.println(test.goodTriplets(
                new int[]{1, 6, 2, 8, 5, 0, 4, 7, 3, 9},
                new int[]{0, 9, 6, 3, 5, 8, 1, 4, 2, 7}));

        // Large input as a performance check; inputs are permutations of 0..length-1.
        int length = 100000;
        Random random = new Random(2179L);
        int[] nums1 = randomPermutation(length, random);
        int[] nums2 = randomPermutation(length, random);
        System.out.println(test.goodTriplets(nums1, nums2));
    }

    /** Returns a uniformly shuffled permutation of {@code 0..n-1} (Fisher-Yates). */
    private static int[] randomPermutation(int n, Random random) {
        int[] perm = new int[n];
        for (int i = 0; i < n; i++) {
            perm[i] = i;
        }
        for (int i = n - 1; i > 0; i--) {
            int j = random.nextInt(i + 1);
            int tmp = perm[i];
            perm[i] = perm[j];
            perm[j] = tmp;
        }
        return perm;
    }

    /**
     * Counts the good triplets of two arrays: triples of values {@code (x, y, z)} whose indices
     * are strictly increasing in both {@code nums1} and {@code nums2}.
     *
     * <p>Neither input array is modified.
     *
     * @param nums1 first array; its values are used as indices, so they must lie in
     *              {@code [0, n)} (assumed to be a permutation, as in the problem statement)
     * @param nums2 second array, same length and value domain as {@code nums1}
     * @return the number of good triplets (may exceed {@code int} range, hence {@code long})
     * @throws IllegalArgumentException if the arrays have different lengths
     */
    public long goodTriplets(int[] nums1, int[] nums2) {
        if (nums1.length != nums2.length) {
            throw new IllegalArgumentException(
                    "nums1 and nums2 must have the same length: " + nums1.length + " != " + nums2.length);
        }
        int length = nums2.length;
        // positionInNums2[v] = index at which value v occurs in nums2.
        int[] positionInNums2 = new int[length];
        for (int i = 0; i < length; i++) {
            positionInNums2[nums2[i]] = i;
        }
        long total = 0;
        Node root = new Node(0, length - 1);
        for (int value : nums1) {
            int t = positionInNums2[value];
            // prefix[0]: earlier positions <= t (the "smaller" elements before this one);
            // prefix[1]: increasing pairs ending at those positions = triples ending at t.
            long[] prefix = find(t, root);
            total += prefix[1];
            // Record t with the number of pairs that end at it, for future triples.
            add(t, prefix[0], root);
        }
        return total;
    }

    /**
     * Inserts position {@code t} into the tree, adding 1 to {@code count} and
     * {@code smallerBefore} to {@code sum} on every node whose range contains {@code t}.
     * Child nodes are created on demand.
     */
    private void add(int t, long smallerBefore, Node root) {
        Node node = root;
        while (true) {
            node.count += 1;
            node.sum += smallerBefore;
            if (node.st == node.end) {
                return;
            }
            int m = node.st + ((node.end - node.st) >> 1);
            if (t > m) {
                if (node.right == null) {
                    node.right = new Node(m + 1, node.end);
                }
                node = node.right;
            } else {
                if (node.left == null) {
                    node.left = new Node(node.st, m);
                }
                node = node.left;
            }
        }
    }

    /**
     * Returns {@code {count, sum}} aggregated over all inserted positions in {@code [0, t]}.
     * Missing (never allocated) subtrees contribute zero.
     */
    private long[] find(int t, Node root) {
        long count = 0;
        long sum = 0;
        Node node = root;
        while (node != null) {
            if (node.st == node.end) {
                count += node.count;
                sum += node.sum;
                break;
            }
            int m = node.st + ((node.end - node.st) >> 1);
            if (t > m) {
                // The whole left half lies below t: take its totals and continue right.
                if (node.left != null) {
                    count += node.left.count;
                    sum += node.left.sum;
                }
                node = node.right;
            } else {
                node = node.left;
            }
        }
        return new long[]{count, sum};
    }

    /** Segment-tree node covering positions {@code [st, end]}. */
    private static final class Node {
        private final int st;
        private final int end;
        /** Number of positions inserted within [st, end]. */
        private long count;
        /** Sum, over inserted positions in [st, end], of the smaller-earlier counts recorded by add. */
        private long sum;
        private Node left;
        private Node right;

        Node(int st, int end) {
            this.st = st;
            this.end = end;
        }
    }

}
